{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7d6410a1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('unsloth/Meta-Llama-3.1-8B-Instruct')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8192d241",
   "metadata": {},
   "outputs": [],
   "source": [
    "from streaming import MDSWriter\n",
    "from streaming.base.format.mds.encodings import Encoding, _encodings\n",
    "from streaming import LocalDataset\n",
    "import streaming\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from glob import glob\n",
    "import os\n",
    "import json\n",
    "\n",
    "class UInt32(Encoding):\n",
    "    def encode(self, obj) -> bytes:\n",
    "        return obj.tobytes()\n",
    "\n",
    "    def decode(self, data: bytes):\n",
    "        return np.frombuffer(data, np.uint32)\n",
    "\n",
    "_encodings['uint32'] = UInt32\n",
    "\n",
    "columns = {\n",
    "    'input_ids': 'uint32',\n",
    "    'position_ids': 'uint32',\n",
    "    'attention_mask': 'uint32',\n",
    "}\n",
    "hashes = 'sha1', 'xxh64'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c89a77fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mkdir: cannot create directory ‘tokenized’: File exists\r\n"
     ]
    }
   ],
   "source": [
    "!mkdir tokenized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "75fa5332",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2032625it [00:14, 140510.96it/s]\n"
     ]
    }
   ],
   "source": [
    "combine = []\n",
    "with open('combined-instructions-language.jsonl') as fopen:\n",
    "    for l in tqdm(fopen):\n",
    "        l = json.loads(l)\n",
    "        combine.append(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "41639383",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('reasoning_instructions.json') as fopen:\n",
    "    reasoning = json.load(fopen)\n",
    "    \n",
    "combine.extend(reasoning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a4780c0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "def collator(batch, batch_position_ids):\n",
    "    input_ids = []\n",
    "    position_ids = []\n",
    "    masks = []\n",
    "    for i in range(len(batch)):\n",
    "        l = len(batch[i])\n",
    "        input_ids.extend(batch[i])\n",
    "        position_ids.extend(batch_position_ids[i])\n",
    "        masks.append(l)\n",
    "    \n",
    "    return {\n",
    "        'input_ids': np.array(input_ids).astype(np.uint32),\n",
    "        'position_ids': np.array(position_ids).astype(np.uint32),\n",
    "        'attention_mask': np.array(masks).astype(np.uint32),\n",
    "    }\n",
    "\n",
    "def slice_and_balance(nested_list, size):\n",
    "    first = []\n",
    "    balance = []\n",
    "    current_size = 0\n",
    "\n",
    "    for sublist in nested_list:\n",
    "        if current_size < size:\n",
    "            remaining_space = size - current_size\n",
    "            if len(sublist) <= remaining_space:\n",
    "                first.append(sublist)\n",
    "                current_size += len(sublist)\n",
    "            else:\n",
    "                first.append(sublist[:remaining_space])\n",
    "                balance.append(sublist[remaining_space:])\n",
    "                current_size = size\n",
    "        else:\n",
    "            balance.append(sublist)\n",
    "    \n",
    "    return first, balance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d98b2a7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                       | 17/10000 [00:00<00:23, 420.14it/s]\n"
     ]
    }
   ],
   "source": [
    "block_size = 16384\n",
    "count = 0\n",
    "temp = []\n",
    "position_ids = []\n",
    "found = False\n",
    "for row in tqdm(combine[:10000]):\n",
    "    prompt = tokenizer.apply_chat_template(row, tokenize=False)\n",
    "    outputs = tokenizer(prompt, add_special_tokens = False)\n",
    "    temp.append(outputs['input_ids'])\n",
    "    position_ids.append(range(len(outputs['input_ids'])))\n",
    "    count += len(outputs['input_ids'])\n",
    "    while count >= block_size:\n",
    "        block, temp = slice_and_balance(temp, block_size)\n",
    "        block_position, position_ids = slice_and_balance(position_ids, block_size)\n",
    "        count = count - block_size\n",
    "        found = True\n",
    "        break\n",
    "    if found:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "39327baf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "687"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "73d84295",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 818 µs, sys: 202 µs, total: 1.02 ms\n",
      "Wall time: 988 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "o = collator(block, block_position)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9d8c5dca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def loop(files, block_size = 8192):\n",
    "    rows, index = files\n",
    "    out_root = f'tokenized/tokenized-{index}'\n",
    "    os.system(f'rm -rf {out_root}')\n",
    "    count = 0\n",
    "    temp = []\n",
    "    position_ids = []\n",
    "    last_block, last_position_block = None, None\n",
    "    with MDSWriter(out=out_root, columns=columns, compression=None, hashes=hashes) as out:\n",
    "        for row in tqdm(rows):\n",
    "            prompt = tokenizer.apply_chat_template(row, tokenize=False)\n",
    "            outputs = tokenizer(prompt, add_special_tokens = False)\n",
    "            temp.append(outputs['input_ids'])\n",
    "            position_ids.append(range(len(outputs['input_ids'])))\n",
    "            count += len(outputs['input_ids'])\n",
    "            while count >= block_size:\n",
    "                block, temp = slice_and_balance(temp, block_size)\n",
    "                block_position, position_ids = slice_and_balance(position_ids, block_size)\n",
    "                count = count - block_size\n",
    "                o = collator(block, block_position)\n",
    "                last_block = block\n",
    "                last_position_block = block_position\n",
    "                out.write(o)\n",
    "                \n",
    "        block, _ = slice_and_balance(last_block, block_size - count)\n",
    "        block_position, _ = slice_and_balance(last_position_block, block_size - count)\n",
    "\n",
    "        block.extend(temp)\n",
    "        block_position.extend(position_ids)\n",
    "\n",
    "        o = collator(block, block_position)\n",
    "        if len(o['input_ids']) == block_size:\n",
    "            out.write(o)\n",
    "            return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "234bc909",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████| 10000/10000 [00:13<00:00, 751.30it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': array([  9188,   2308,    276, ...,    276,     13, 128009], dtype=uint32),\n",
       " 'position_ids': array([244, 245, 246, ..., 902, 903, 904], dtype=uint32),\n",
       " 'attention_mask': array([ 651,   73,  399,  838,  977,  889, 1046,  920,  804,  119,   87,\n",
       "         484,  905], dtype=uint32)}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loop((combine[:10000], 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc54adc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "100%|█████████████████████████████████| 100000/100000 [00:41<00:00, 2417.24it/s]\n",
      "100%|█████████████████████████████████| 100000/100000 [00:39<00:00, 2536.12it/s]\n",
      "100%|█████████████████████████████████| 100000/100000 [00:45<00:00, 2197.84it/s]\n",
      "100%|█████████████████████████████████| 100000/100000 [00:45<00:00, 2184.44it/s]\n",
      "100%|█████████████████████████████████| 100000/100000 [01:16<00:00, 1304.73it/s]\n",
      "100%|██████████████████████████████████| 100000/100000 [02:05<00:00, 799.64it/s]\n",
      "100%|██████████████████████████████████| 100000/100000 [01:42<00:00, 979.54it/s]\n",
      "  9%|███▎                                | 9265/100000 [01:14<13:59, 108.04it/s]"
     ]
    }
   ],
   "source": [
    "from multiprocess import Pool\n",
    "import mp\n",
    "\n",
    "chunks = mp.chunks(combine, 100000)\n",
    "pool = Pool(15)\n",
    "pooled = pool.map(loop, chunks)\n",
    "pool.close()\n",
    "pool.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4ad9c23",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "folders = sorted(glob('tokenized/tokenized-*'), key = lambda x: int(x.split('-')[-1]))\n",
    "folders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a09cca84",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = LocalDataset(local = 'tokenized/tokenized-0')\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353c9334",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf packing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1190255",
   "metadata": {},
   "outputs": [],
   "source": [
    "with MDSWriter(out='packing', columns=columns, compression=None, hashes=hashes) as out:\n",
    "    for f in folders:\n",
    "        try:\n",
    "            dataset = LocalDataset(local=f)\n",
    "            for i in tqdm(range(len(dataset))):\n",
    "                out.write(dataset[i])\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76cc2e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = LocalDataset('packing')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4fafc68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import create_repo, delete_repo\n",
    "\n",
    "try:\n",
    "    delete_repo(repo_id=\"mesolitica/malaysian-llama3.1-8k-language-multipack\", repo_type=\"dataset\")\n",
    "except:\n",
    "    pass\n",
    "create_repo(\"mesolitica/malaysian-llama3.1-8k-language-multipack\", repo_type=\"dataset\", private = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17d27ae4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import HfApi\n",
    "api = HfApi()\n",
    "\n",
    "api.upload_folder(\n",
    "    folder_path=\"packing\",\n",
    "    repo_id=\"mesolitica/malaysian-llama3.1-8k-language-multipack\",\n",
    "    repo_type=\"dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb5471d0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3.10",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
